Upgrade to ONNX 1.19.0 #289
Conversation
Walkthrough

Adds batch-size plumbing to the ONNX PTQ example; introduces FP4 dtype support and refactors FP8/FP4 casting and initializer creation in ONNX QDQ utilities; expands the Torch ONNX conversion trigger to include INT4; updates ONNX extras and example requirements; adds ONNX-version-based test gating and adapts unit/GPU tests and unit test data/layouts.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant CLI as torch_quant_to_onnx.py
    participant Model as Torch Model
    participant Loader as Calibration DataLoader
    participant Export as ONNX Export
    User->>CLI: run with --batch_size N
    CLI->>CLI: get_model_input_shape(model, N)
    CLI->>Loader: load_calibration_data(model, data_size, N, device)
    Loader-->>CLI: DataLoader (batch=N)
    alt calibration enabled
        CLI->>Model: calibrate using Loader
    end
    CLI->>Export: export ONNX with input shape [N,...]
    Export-->>User: ONNX model (batched)
```
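To ground the first diagram, here is a minimal, hypothetical sketch of the batch-size plumbing: an argparse flag feeding a calibration DataLoader. The flag name matches the PR; the helper name and tensor shapes are illustrative only, not the example script's actual code.

```python
import argparse

import torch


def build_calib_loader(samples, batch_size):
    """Illustrative only: batch CPU calibration tensors so export sees shape [N, ...]."""
    return torch.utils.data.DataLoader(samples, batch_size=batch_size, shuffle=False)


parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for calibration.")
args = parser.parse_args()

# Hypothetical 3x224x224 image inputs; the real script derives the size from the timm config.
samples = [torch.randn(3, 224, 224) for _ in range(8)]
for batch in build_calib_loader(samples, args.batch_size):
    print(batch.shape)  # torch.Size([N, 3, 224, 224]) with N = --batch_size
```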
```mermaid
sequenceDiagram
    autonumber
    participant Graph as qdq_utils
    participant Weights as Weight Tensor
    participant CastRoutine as _cast_fp4/_cast_fp8
    participant Init as onnx.helper.make_tensor
    Graph->>Weights: select quantizable weights
    alt FP4 target
        Weights->>CastRoutine: _cast_fp4 (pack 2x4-bit -> uint8)
        CastRoutine-->>Init: raw_data (packed uint8), dtype=Float4
        Init->>Graph: create FP4 initializer
    else FP8 target
        Weights->>CastRoutine: _cast_fp8 (flat uint8)
        CastRoutine-->>Init: raw_data (uint8), dtype=Float8
        Init->>Graph: create FP8 initializer
    end
    Graph->>Graph: detect/remove pre-quant-scale Casts
    Graph->>Graph: rewrite QDQ->DQ / FP4QDQ->2DQ nodes and value_info
```
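To make the "pack 2×4-bit → uint8" step above concrete, here is a small NumPy sketch of packing pairs of 4-bit codes into bytes. It packs along the first axis purely for illustration and assumes the 4-bit codes have already been computed; the repository's own packing layout may differ.

```python
import numpy as np


def pack_fp4_codes(codes: np.ndarray) -> np.ndarray:
    """Pack pairs of 4-bit codes (values 0..15) along axis 0 into one uint8 each."""
    assert codes.shape[0] % 2 == 0, "first dimension must be even to pack pairs"
    lo = codes[0::2].astype(np.uint8)  # even rows -> low nibble
    hi = codes[1::2].astype(np.uint8)  # odd rows  -> high nibble
    return lo | (hi << 4)              # shape: (codes.shape[0] // 2, ...)


codes = np.random.randint(0, 16, size=(4, 3), dtype=np.uint8)
packed = pack_fp4_codes(codes)
print(packed.shape)  # (2, 3)
```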
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Pre-merge checks (2 passed, 1 warning)
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (8)
modelopt/onnx/quantization/qdq_utils.py (1)
928-930: Use op_type-based detection for Constant nodes

Replace the substring match on `"Constant"` in `reshape_node.input` with an explicit check of the producer node's `op_type` to avoid relying on name patterns. For example:

```diff
- shape_constant_name = next(i for i in reshape_node.input if "Constant" in i)
+ shape_constant_name = next(
+     i for i in reshape_node.input
+     if tensor_producer_map[i].op_type == "Constant"
+ )
```

tests/unit/onnx/test_qdq_utils.py (4)
70-76: Redundant Cast between Reshape and Transpose (optional)

`dq_output` is already FLOAT; inserting `weight_cast` to FLOAT is a no-op unless it's intentionally exercising cast-conversion logic. If this Cast is only for testing that path, consider a brief comment or renaming the node to indicate intent; otherwise, remove it to keep the pattern minimal (DQ → Reshape → Transpose).

Also applies to: 80-80
96-96: Avoid seeding an unused scale initializer when constant_scale=True

When `constant_scale=True`, the graph still includes a `scale` initializer up-front, which can make the "new scale initializer" assertion trivially pass. Suggest emitting it conditionally so the test precisely validates that the pass rewires the scale.

Apply this diff:

```diff
- nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
+ nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
@@
- initializer=[weight_tensor, scale_tensor],
+ initializer=[weight_tensor] if constant_scale else [weight_tensor, scale_tensor],
```

Also applies to: 104-104
252-257: Name-based Cast preservation: tighten the match (nit)

`if "norm/Cast" in node.name` may over-match unrelated nodes. Consider anchoring (e.g., regex `(^|/)layer_norm/Cast$`) to avoid false positives as models grow.
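As a sketch of the anchored match suggested in this nit (the node names below are made up for illustration):

```python
import re

# Anchor the pattern so only the layer-norm Cast is preserved, not arbitrary "*norm/Cast*" names.
PRESERVE_CAST = re.compile(r"(^|/)layer_norm/Cast$")

for name in ["encoder/layer_norm/Cast", "encoder/norm/Cast_1", "batch_norm/Cast/extra"]:
    print(name, bool(PRESERVE_CAST.search(name)))
# encoder/layer_norm/Cast True
# encoder/norm/Cast_1 False
# batch_norm/Cast/extra False
```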
315-316: FP4 tests look consistent with packed 4-bit along axis 0; add a negative/edge test

The 2D inputs and uint8 expectations align with `_cast_fp4` packing the first dimension by 2. Add a failure case for an odd first dimension to lock the contract, and optionally a 4D case to exercise batching.

Example additions:

```python
def test_cast_fp4_odd_first_dim_raises():
    with pytest.raises(AssertionError):
        _cast_fp4(np.zeros((3, 2), dtype=np.float32))


def test_cast_fp4_4d_batch():
    x = np.random.randn(2, 2, 2, 2).astype(np.float32)  # first dim even
    y = _cast_fp4(x)
    assert y.dtype == np.uint8
    assert y.shape[0] == 1 and y.shape[1:] == x.shape[1:]
```

Also applies to: 320-322, 325-327, 330-332, 335-337, 340-342, 348-348
examples/onnx_ptq/torch_quant_to_onnx.py (3)

86-92: Avoid heavy model instantiation just to read the input size (optional)

Creating a full pretrained model here can be slow. If available in your timm version, prefer fetching the default cfg to get `input_size` without instantiating weights; otherwise, at least consider `pretrained=False` for this helper.
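A rough sketch of the lighter-weight variant suggested above, assuming `timm.data.resolve_model_data_config` works on a model created with `pretrained=False` (so no weight download is needed):

```python
import timm


def get_model_input_shape(model_name: str, batch_size: int) -> tuple:
    # Build the architecture only; skip downloading pretrained weights.
    model = timm.create_model(model_name, pretrained=False, num_classes=1000)
    input_size = timm.data.resolve_model_data_config(model)["input_size"]
    return (batch_size, *input_size)


print(get_model_input_shape("resnet50", 4))  # e.g. (4, 3, 224, 224)
```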
122-127: Add basic validation for `--batch_size` (and align with data size) (optional)

Help text implies constraints but they're not enforced. Guard against invalid values and optionally round `calibration_data_size` to full batches for determinism.

Example (place after `args = parser.parse_args()`):

```python
if args.batch_size <= 0:
    raise ValueError("--batch_size must be > 0")
if args.calibration_data_size <= 0:
    raise ValueError("--calibration_data_size must be > 0")
# Optional: ensure full batches only
# args.calibration_data_size = (args.calibration_data_size // args.batch_size) * args.batch_size
```
141-149: DataLoader returns GPU tensors from worker processes (risk of pickling/IPC issues)

`load_calibration_data` constructs CUDA tensors and then uses `num_workers=4`. Multiprocessing with GPU tensors can be brittle and memory-heavy. Prefer keeping tensors on CPU in the loader and moving them to device in the forward loop, or set `num_workers=0` for simplicity.

Suggested adjustments:

- Keep calibration tensors on CPU in `load_calibration_data`, and enable `pin_memory=True`:

```python
return torch.utils.data.DataLoader(
    calib_tensor, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True
)
```

- Move batches to device inside `forward_loop`:

```python
def forward_loop(model):
    for batch in data_loader:
        if isinstance(batch, torch.Tensor):
            batch = batch.to(device, non_blocking=True)
        model(batch)
```

- Optionally set `shuffle=False` for reproducible calibration.
📒 Files selected for processing (5)
- examples/onnx_ptq/torch_quant_to_onnx.py (2 hunks)
- modelopt/onnx/quantization/qdq_utils.py (11 hunks)
- modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
- setup.py (1 hunks)
- tests/unit/onnx/test_qdq_utils.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
- tests/unit/onnx/test_qdq_utils.py (1): _cast_fp4 (614-626) from modelopt/onnx/quantization/qdq_utils.py
- modelopt/onnx/quantization/qdq_utils.py (1): NVFP4QTensor (31-295) from modelopt/torch/quantization/qtensor/nvfp4_tensor.py
🔇 Additional comments (11)
setup.py (1)

50-51: LGTM! Correct version upgrades for ONNX dependencies.

The ONNX version upgrades align with the PR objectives:
- onnx~=1.19.0: enables FP4 dtype support (FLOAT4E2M1)
- onnxconverter-common~=1.16.0: ensures compatibility with the new ONNX version

modelopt/onnx/quantization/qdq_utils.py (8)
49-57: LGTM! Proper addition of the FP4 dtype mapping.

The addition of `"Float4": onnx.TensorProto.FLOAT4E2M1` correctly extends the dtype map to support FP4 quantization, which is one of the key features in this ONNX upgrade.
604-612: Critical fix: FP8 casting now returns a correct flat uint8 array.

The change from a structured dtype to a flat uint8 array aligns with ONNX's expected data format for FP8 tensors. This is a necessary fix for proper FP8 quantization.
693-694: LGTM! Consistent use of the dtype map for FP8 checking.

The change to use `onnx_dtype_map["Float8"]` maintains consistency with the updated dtype handling throughout the codebase.
951-956: LGTM! Proper graph simplification by removing unnecessary Cast nodes.

The removal of Cast nodes between Reshape and Transpose operations optimizes the graph structure for INT4 quantization.
1007-1021: Good optimization: removing the redundant Cast after the pre-quant scale.

The helper function `is_pre_quant_scale_node` correctly identifies and removes unnecessary Cast nodes following pre-quantization scale operations, improving graph efficiency.
1131-1138: LGTM! Proper FP8 tensor creation using the ONNX helper.

The use of `onnx.helper.make_tensor` with an explicit Float8 dtype and raw bytes ensures correct FP8 weight representation in the ONNX graph.
1219-1236: LGTM! Correct FP4 tensor creation with proper dimensions.

The FP4 tensor creation correctly:
- Uses `onnx_dtype_map["Float4"]` for the data type
- Adjusts dimensions to account for packing (2x values in the first dim)
- Uses raw bytes for efficient storage
595-596: Float8 dtype check is backward compatible

`onnx_dtype_map["Float8"]` still resolves to the original constant (`onnx.TensorProto.FLOAT8E4M3FN`), and no other FLOAT8 variants or direct dtype checks were replaced; existing models remain supported.

modelopt/torch/_deploy/utils/torch_onnx.py (1)
488-489: Verify INT4 quantization detection and add tests

Ensure `is_int4_quantized(model)` is implemented and correctly flags INT4-quantized models, and add unit tests covering the INT4→FP16 conversion path in the ONNX utility tests.

tests/unit/onnx/test_qdq_utils.py (1)
55-61: Good change: make the Reshape shape a Constant input

Using a Constant for the reshape shape is cleaner and avoids managing a dedicated initializer. Wiring it into Reshape looks correct.

Also applies to: 65-65
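For reference, a minimal sketch of the Constant-fed Reshape pattern the test now builds; the tensor and node names here are illustrative, not the test's exact identifiers.

```python
import numpy as np
import onnx
from onnx import helper, numpy_helper

# Constant node carrying the target shape, wired as Reshape's second input.
shape_const = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["reshape_shape"],
    name="reshape_shape_constant",
    value=numpy_helper.from_array(np.array([4, 16], dtype=np.int64), name="reshape_shape_value"),
)
reshape = helper.make_node(
    "Reshape", inputs=["dq_output", "reshape_shape"], outputs=["reshaped"], name="reshape"
)
```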
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #289      +/-   ##
==========================================
- Coverage   73.94%   73.87%   -0.07%
==========================================
  Files         172      172
  Lines       17405    17438      +33
==========================================
+ Hits        12870    12883      +13
- Misses       4535     4555      +20
```

☔ View full report in Codecov by Sentry.
Please fix the rabbit's comment
Force-pushed d792d4f to c06fcac
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
examples/onnx_ptq/torch_quant_to_onnx.py (1)

62-67: Don't move dataset tensors to CUDA before the DataLoader; move per-batch inside the loop

Creating a DataLoader over GPU tensors with num_workers>0 can hang/fail (IPC/pickling). Keep data on CPU, enable pin_memory, and move to device inside the forward loop.

```diff
-    calib_tensor = [t.to(device) for t in calib_tensor]
-    return torch.utils.data.DataLoader(
-        calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4
-    )
+    return torch.utils.data.DataLoader(
+        calib_tensor,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=4,
+        pin_memory=(device.type == "cuda"),
+    )
```

And in quantize_model's forward pass:

```diff
-    def forward_loop(model):
-        for batch in data_loader:
-            model(batch)
+    def forward_loop(model):
+        device = next(model.parameters()).device
+        with torch.inference_mode():
+            for batch in data_loader:
+                if isinstance(batch, (list, tuple)):
+                    batch = torch.stack(batch)  # safety if collation returns list
+                batch = batch.to(device, non_blocking=True)
+                model(batch)
```

modelopt/onnx/quantization/qdq_utils.py (4)
629-633: Create the FP8 initializer with the Float8 dtype and raw bytes

The current code produces a UINT8 tensor. Use make_tensor with Float8 and raw bytes to preserve the dtype.

```diff
-def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
-    """Create a FLOAT8E4M3FN tensor directly from numpy array."""
-    fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
+    """Create a FLOAT8E4M3FN tensor with correct dtype and raw bytes."""
+    fp8_bytes = _cast_fp8(scaled).tobytes()
+    return onnx.helper.make_tensor(
+        name=weight_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=list(scaled.shape),
+        vals=fp8_bytes,
+        raw=True,
+    )
```
595-601: Bug: the dtype check compares a NumPy dtype to an ONNX enum

`zp_array.dtype == onnx_dtype_map["Float8"]` will never be true. You lost the ONNX dtype when converting to NumPy. Propagate the TensorProto dtype and branch on that.

```diff
-    if zp_array.dtype == onnx_dtype_map["Float8"]:
+    if zp_dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
         np.clip(scaled + zp_array, -128, 127, out=scaled)
```

Outside this hunk, update the helpers to return/accept dtypes:

```python
# In _get_scale_and_zp(...): return both arrays and their ONNX data_type enums
def _get_scale_and_zp(...) -> tuple[np.ndarray, np.ndarray, int, int]:
    ...
    scale_dtype = scale.data_type
    zp_dtype = zp.data_type
    scale_array = onnx.numpy_helper.to_array(scale)
    zp_array = onnx.numpy_helper.to_array(zp)
    return scale_array, zp_array, scale_dtype, zp_dtype

# Update _convert_weight signature to accept zp_dtype (int ONNX enum)
def _convert_weight(..., zp_dtype: int, ...) -> np.ndarray:
    ...

# And pass it from qdq_to_dq:
scale_array, zp_array, _, zp_dtype = _get_scale_and_zp(...)
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node, zp_dtype)
```
693-699: The QDQ path must use the ONNX dtype of the zero-point, not the NumPy dtype

Follow-on to the previous fix: branch on `zp_dtype`.

```diff
-        if zp_array.dtype == onnx_dtype_map["Float8"]:
+        if zp_dtype == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
             new_weight = onnx.numpy_helper.from_array(scaled.astype("int8"), weight_name)
             logger.debug(f"Converted {weight_name} to INT8")
```
604-611: Use an explicit FLOAT8E4M3FN when creating the FP8 tensor

Replace the call to `onnx.numpy_helper.from_array(fp8_data, weight_name)` in `_create_fp8_tensor` with an explicit `make_tensor` invocation that sets `data_type=onnx.TensorProto.FLOAT8E4M3FN`, passes the raw bytes, and specifies the correct shape. For example:

```python
import onnx

def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
    fp8_data = _cast_fp8(scaled)
    return onnx.helper.make_tensor(
        name=weight_name,
        data_type=onnx.TensorProto.FLOAT8E4M3FN,
        dims=fp8_data.shape,
        vals=fp8_data.tobytes(),
        raw=True,
    )
```

This ensures the initializer's data_type is FLOAT8E4M3FN rather than UINT8.
♻️ Duplicate comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)

614-627: FP4 packing is along the wrong axis; the assertion is on axis 0 instead of the last axis; the docstring is missing the requirement

FP4 should be packed along the last dimension (matching NVFP4QTensor.quantize). Asserting and reshaping on dim 0 will break many weight shapes. Also document the even-length requirement.

```diff
-def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
-    array_f32_t = torch.from_numpy(array)
-    array_f32_t_shape = array_f32_t.shape
-    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
-    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
-    if torch.cuda.is_available():
-        array_f32_t = array_f32_t.cuda()
-    array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4_t = array_f4_t.flatten()
-    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
-    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
-    return array_f4
+def _cast_fp4(array: np.ndarray) -> np.ndarray:
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The last dimension must be even; two FP4 values are packed per byte.
+    """
+    array_f32_t = torch.from_numpy(array)
+    if array_f32_t.shape[-1] % 2 != 0:
+        raise ValueError(
+            f"Last dimension must be divisible by 2 for FP4 packing; got {array_f32_t.shape[-1]}"
+        )
+    if torch.cuda.is_available():
+        array_f32_t = array_f32_t.cuda()
+    q4 = NVFP4QTensor._cast_fp4(array_f32_t)  # values in [0..15], same shape as input
+    packed = (q4[..., 1::2] << 4) | q4[..., 0::2]  # pack along last dim
+    return packed.cpu().numpy().astype(np.uint8)  # shape: (*, last_dim//2)
```
🧹 Nitpick comments (8)
setup.py (1)
50-51: ONNX/ORT version alignment check

Bumping to onnx~=1.19.0 and onnxconverter-common~=1.16.0 looks fine. Please verify that ORT packages still resolve and support the new dtypes (Float4/Float8) used elsewhere. Also consider aligning onnxruntime-directml to ~=1.22.0 for consistency unless there's a known constraint.
examples/onnx_ptq/torch_quant_to_onnx.py (3)
86-92: Input shape: validate batch_size > 0

The return-shape logic is good. Add a guard for batch_size >= 1 to avoid silent mis-shapes.

```diff
 def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (batch_size, *tuple(input_size))  # Add batch dimension
+    if batch_size < 1:
+        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
+    return (batch_size, *tuple(input_size))  # Add batch dimension
```
122-127: CLI: constrain batch_size

Add argparse-level validation and clarify the help text.

```diff
     parser.add_argument(
         "--batch_size",
-        type=int,
-        default=1,
-        help="Batch size for calibration.",
+        type=int,
+        default=1,
+        choices=range(1, 1024),
+        metavar="{1..1023}",
+        help="Batch size for calibration (>=1).",
     )
```
131-150: Minor: set eval/inference during quantization

Set model.eval() before quantize and rely on inference_mode in forward_loop (as suggested above) to speed up and stabilize calibration behavior.

```diff
-    # Quantize model
-    quantized_model = quantize_model(model, config, data_loader)
+    model.eval()
+    quantized_model = quantize_model(model, config, data_loader)
```

modelopt/onnx/quantization/qdq_utils.py (4)
951-956: Assumes a Cast always follows Reshape

Guard against missing/extra nodes to avoid crashes; gracefully skip if the pattern doesn't match.

```diff
-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Remove unnecessary Cast node (if present)
+        cast_child_nodes = reshape_child_nodes
+        if reshape_child_nodes and reshape_child_nodes[0].op_type == "Cast":
+            cast_node = reshape_child_nodes[0]
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
```
958-977: Be robust if there is no Transpose after the Cast

Avoid indexing without checks.

```diff
-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if cast_child_nodes and cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             ...
-        else:
-            matmul_node = cast_child_nodes[0]
+        else:
+            assert cast_child_nodes, f"No consumer found after Cast/Reshape for {node.name}"
+            matmul_node = cast_child_nodes[0]
```
1007-1021: Pre-quant scale cleanup assumes exactly one Cast child

Loosen the assumptions: skip if there are multiple/no children or the child is not a Cast; only rewire when safe.

```diff
-    for node in graph.node:
-        if is_pre_quant_scale_node(node):
-            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
-            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
-            cast_node = pqs_child_nodes[0]
-            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-            node.output.clear()
-            node.output.extend(cast_node.output)
-            nodes_to_remove.append(cast_node.name)
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            if len(pqs_child_nodes) != 1 or pqs_child_nodes[0].op_type != "Cast":
+                continue
+            cast_node = pqs_child_nodes[0]
+            node.output[:] = cast_node.output
+            nodes_to_remove.append(cast_node.name)
```
1335-1343: Avoid a magic number for BF16 detection

Use the enum for readability and safety.

```diff
-    for initializer in graph.initializer:
-        if initializer.data_type == 16:
+    for initializer in graph.initializer:
+        if initializer.data_type == onnx.TensorProto.BFLOAT16:
             precision_dtype = "BFloat16"
             break
```
📒 Files selected for processing (5)
- examples/onnx_ptq/torch_quant_to_onnx.py (2 hunks)
- modelopt/onnx/quantization/qdq_utils.py (11 hunks)
- modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
- setup.py (1 hunks)
- tests/unit/onnx/test_qdq_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/_deploy/utils/torch_onnx.py
- tests/unit/onnx/test_qdq_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
- modelopt/onnx/quantization/qdq_utils.py (1): NVFP4QTensor (31-295) from modelopt/torch/quantization/qtensor/nvfp4_tensor.py
🔇 Additional comments (3)
modelopt/onnx/quantization/qdq_utils.py (3)
52-54: Float4 mapping added

The Float4 → FLOAT4E2M1 mapping is expected for ONNX 1.19+. LGTM.
1131-1137: MXFP8 initializer creation looks correct

Using make_tensor with the Float8 dtype and raw bytes is the right approach. LGTM.
1230-1236: FP8 scale initializer: consistent approach

Using the Float8 dtype + raw bytes is consistent with the MXFP8 path. LGTM.
```python
w_f4_proto = onnx.helper.make_tensor(
    name=w_f4_name,
    data_type=onnx_dtype_map["Float4"],
    dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
    vals=w_f4.tobytes(),
    raw=True,
)
```
🛠️ Refactor suggestion
FP4 initializer dims should reflect packing along the last axis
After fixing _cast_fp4 to pack along the last dim, adjust dims accordingly.
```diff
- w_f4_proto = onnx.helper.make_tensor(
-     name=w_f4_name,
-     data_type=onnx_dtype_map["Float4"],
-     dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
-     vals=w_f4.tobytes(),
-     raw=True,
- )
+ w_f4_proto = onnx.helper.make_tensor(
+     name=w_f4_name,
+     data_type=onnx_dtype_map["Float4"],
+     dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
+     vals=w_f4.tobytes(),
+     raw=True,
+ )
```
Signed-off-by: ajrasane <[email protected]>
Signed-off-by: ajrasane <[email protected]>
Signed-off-by: ajrasane <[email protected]>
Force-pushed 82d26ec to 3420d48
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/onnx/quantization/qdq_utils.py (2)
595-601: Bug: FP8 detection via NumPy dtype is incorrect

Comparing zp_array.dtype to an ONNX enum won't work; NumPy returns uint8 (or a structured uint8), not an ONNX dtype. This branch will never select the FP8 path.

Fix by checking the zero-point TensorProto's data_type (or pass a boolean). Example patch:

```diff
-def _get_scale_and_zp(
+def _get_scale_and_zp(
     node: onnx.NodeProto,
     initializers: dict[str, onnx.TensorProto],
     tensor_producers: dict[str, onnx.NodeProto],
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[np.ndarray, np.ndarray, int]:
 @@
-    return scale_array, zp_array
+    return scale_array, zp_array, zp.data_type
 @@
-def _convert_weight(
+def _convert_weight(
     weight_array: np.ndarray,
     scale_array: np.ndarray,
-    zp_array: np.ndarray,
+    zp_array: np.ndarray,
+    zp_dtype: int,
     quantized_node: onnx.NodeProto,
 ) -> np.ndarray:
 @@
-    if zp_array.dtype == onnx_dtype_map["Float8"]:
+    if zp_dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
         np.clip(scaled + zp_array, -128, 127, out=scaled)
 @@
-    scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
+    scale_array, zp_array, zp_dtype = _get_scale_and_zp(node, initializers, tensor_producers)
 @@
-    scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
+    scaled = _convert_weight(weight_array, scale_array, zp_array, zp_dtype, quantized_node)
```
633-637: Create FP8 initializers with the proper ONNX dtype, not numpy_helper.from_array

numpy_helper.from_array will tag the data as UINT8, not Float8. Use make_tensor with data_type Float8 and raw bytes (as done below in quantize_weights_to_mxfp8).

```diff
-def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
-    """Create a FLOAT8E4M3FN tensor directly from numpy array."""
-    fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
+    """Create a FLOAT8E4M3FN initializer with correct dtype."""
+    fp8_data = _cast_fp8(scaled)
+    return onnx.helper.make_tensor(
+        name=weight_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*scaled.shape],
+        vals=fp8_data.tobytes(),
+        raw=True,
+    )
```
♻️ Duplicate comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)
932-935: Reshape shape Constant detection is brittle (re-raising prior feedback)

String-matching "Constant" in input names is fragile. Use Reshape's second input and check its producer.

```diff
-        # Remove constant node from reshape node
-        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
-        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+        # Remove Constant that feeds Reshape's shape, if present
+        if len(reshape_node.input) >= 2:
+            shape_name = reshape_node.input[1]
+            shape_producer = tensor_producer_map.get(shape_name)
+            if shape_producer is not None and shape_producer.op_type == "Constant":
+                nodes_to_remove.append(shape_producer.name)
```
🧹 Nitpick comments (2)
tests/unit/onnx/test_qdq_utils.py (1)
315-342: FP4 tests: packed shape and dtype OK; add an odd-first-dim guard test
- Expecting a packed shape along the first axis (2xN → 1xN) and dtype uint8 matches _cast_fp4’s contract.
- Consider adding a negative test where the first dim is odd to assert the error path.
I can add a parametric test that asserts the raised error for odd first-dimension inputs. Want me to push it?
Also applies to: 348-348
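A possible shape for that parametric guard test, assuming `_cast_fp4` keeps rejecting an odd first dimension via its existing assert:

```python
import numpy as np
import pytest

from modelopt.onnx.quantization.qdq_utils import _cast_fp4


@pytest.mark.parametrize("shape", [(3, 2), (5, 4), (1, 8)])
def test_cast_fp4_rejects_odd_first_dim(shape):
    # Packing pairs two values along the first axis, so an odd first dim must fail.
    with pytest.raises(AssertionError):
        _cast_fp4(np.zeros(shape, dtype=np.float32))
```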
modelopt/onnx/quantization/qdq_utils.py (1)
1223-1229: Confirm FP4 initializer dims align with the packing axis

Dims double the first axis, consistent with _cast_fp4's first-axis packing. Verify all consumers assume this convention; earlier feedback suggested last-axis packing, so ensure consistency across exporters/importers.
If you plan to switch packing to the last axis later, centralize the “packed axis” in one utility to avoid mismatches.
📒 Files selected for processing (5)
- examples/onnx_ptq/torch_quant_to_onnx.py (2 hunks)
- modelopt/onnx/quantization/qdq_utils.py (11 hunks)
- modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
- setup.py (1 hunks)
- tests/unit/onnx/test_qdq_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/onnx_ptq/torch_quant_to_onnx.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- setup.py
🧰 Additional context used
🧬 Code graph analysis (2)
- modelopt/onnx/quantization/qdq_utils.py (1): NVFP4QTensor (31-295) from modelopt/torch/quantization/qtensor/nvfp4_tensor.py
- tests/unit/onnx/test_qdq_utils.py (1): _cast_fp4 (614-630) from modelopt/onnx/quantization/qdq_utils.py
🔇 Additional comments (7)
tests/unit/onnx/test_qdq_utils.py (4)
55-66: Reshape shape via Constant looks good

Using a Constant for the reshape shape and wiring it as the second input of Reshape aligns with ONNX patterns and avoids depending on initializers.
70-81: Explicit Cast between Reshape and Transpose is fine

Adding the Cast and feeding Transpose from it mirrors the production graphs these tests target. No issues spotted.
96-96: Node/initializer lists updated correctly

Including the new Constant/Cast nodes in the node list and removing the reshape shape initializer is consistent with the refactor intent.

Also applies to: 104-104
252-257: Cast skip rule matches the implementation

Tests now only preserve FP32 casts whose names contain "norm/Cast"; all other FP32 casts are converted to FP16. This matches the quantization pass behavior.
modelopt/onnx/quantization/qdq_utils.py (3)
52-54: ONNX Float4 mapping added

Mapping "Float4" to FLOAT4E2M1 is correct for ONNX ≥ 1.19.

If CI uses multiple ONNX versions, please ensure FLOAT4E2M1 exists at runtime to avoid import-time failures.
1135-1141: Good: FP8 initializer created with the correct dtype and raw bytes

This mirrors the recommended pattern for Float8 in ONNX.
610-611: Return a structured uint8 dtype in _cast_fp8

Update `_cast_fp8` to return a structured-dtype view so tests pass and serialization still works:

```diff
-    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
-    return array_f8
+    array_f8 = array_f8_t.cpu().numpy()
+    return array_f8.view(np.dtype((np.uint8, [("e4m3fn", "u1")])))
```
```python
def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
    has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
    return node.op_type == "Mul" and has_pqs_input

# Remove unnecessay Cast after Pre-quant scale
for node in graph.node:
    if is_pre_quant_scale_node(node):
        pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
        assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
        cast_node = pqs_child_nodes[0]
        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
        node.output.clear()
        node.output.extend(cast_node.output)
        nodes_to_remove.append(cast_node.name)
```
💡 Verification agent
🧩 Analysis chain
Pre-quant Mul detection is wrong; the naming doesn’t match inserted nodes
The scale input is named with “_awq_scale” and/or produced by a Constant/initializer, not “_pre_quant_scale”. This block won’t trigger.
Make detection structural: Mul with one initializer or Constant input.
```diff
- def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-     has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-     return node.op_type == "Mul" and has_pqs_input
+ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+     if node.op_type != "Mul" or len(node.input) != 2:
+         return False
+     for inp in node.input:
+         if inp in initializer_map:
+             return True
+         prod = tensor_producer_map.get(inp)
+         if prod is not None and prod.op_type == "Constant":
+             return True
+     return False
```
Please run the unit test that exercises removal of Cast after pre-quant Mul to confirm this now triggers.
🏁 Script executed:

```bash
#!/bin/bash
echo "Occurrences of _awq_scale:"
rg -n "_awq_scale" -g "*.py" || true
echo
echo "Occurrences of pre_quant_scale:"
rg -n "pre_quant_scale" -g "*.py" || true
echo
echo "Definition of is_pre_quant_scale_node and context:"
rg -n -C3 "def is_pre_quant_scale_node" -g "modelopt/onnx/quantization/qdq_utils.py" || true
echo
echo "initializer_map occurrences in qdq_utils.py:"
rg -n "initializer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true
echo
echo "tensor_producer_map occurrences in qdq_utils.py:"
rg -n "tensor_producer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true
```
Length of output: 28239
Use structural detection for the pre-quant Mul in qdq_utils.py

The current `is_pre_quant_scale_node` looks for "_pre_quant_scale" in the input name, but the inserted scale tensors are named with "_awq_scale", so this code never matches and the Cast isn't removed. Replace it with a structural check, i.e. a Mul with exactly two inputs, one coming from an initializer or Constant node:
```diff
- def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-     has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-     return node.op_type == "Mul" and has_pqs_input
+ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+     if node.op_type != "Mul" or len(node.input) != 2:
+         return False
+     for inp in node.input:
+         if inp in initializer_map:
+             return True
+         prod = tensor_producer_map.get(inp)
+         if prod is not None and prod.op_type == "Constant":
+             return True
+     return False
```
LGTM
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/_test_utils/import_helper.py (1)
80-92: Allow safe module-level skipping.

If this helper is ever used at import time, the current skips lack allow_module_level=True. Harmless inside tests, but brittle at module level.

```diff
-    except importlib.metadata.PackageNotFoundError:
-        pytest.skip(f"{package_name} is not installed")
+    except importlib.metadata.PackageNotFoundError:
+        pytest.skip(f"{package_name} is not installed", allow_module_level=True)

-    if version.parse(installed_version) < version.parse(required_version):
-        pytest.skip(
-            f"{package_name} version {installed_version} is less than required {required_version}"
-        )
+    if version.parse(installed_version) < version.parse(required_version):
+        pytest.skip(
+            f"{package_name} version {installed_version} is less than required {required_version}",
+            allow_module_level=True,
+        )
```

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2)
43-44: Gate the ONNX version at import time to avoid an ImportError before the skip.

The current skip runs inside the test; the onnx-dependent imports above may fail first in envs with <1.19 or missing onnx. Move the gate to module level and drop the per-test calls.

```diff
-    skip_if_onnx_version_below_1_19()
```

Add near the imports (example):

```python
# Top-level gate (place after standard imports)
from _test_utils.import_helper import skip_if_onnx_version_below_1_19

skip_if_onnx_version_below_1_19()
```

Note: this pairs with adding allow_module_level=True in the helper as suggested.
119-120: Same as above: move the version gate to module level.

Prevents import-time failures and keeps the skip reason centralized.

```diff
-    skip_if_onnx_version_below_1_19()
```
📒 Files selected for processing (3)
- examples/windows/onnx_ptq/genai_llm/requirements.txt (1 hunks)
- tests/_test_utils/import_helper.py (2 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1): skip_if_no_libcudnn (37-43) and skip_if_onnx_version_below_1_19 (80-92) from tests/_test_utils/import_helper.py
🔇 Additional comments (2)
tests/_test_utils/import_helper.py (1)
16-20: Ensure test dependency availability.

packaging is used at test time. Confirm it's declared in test/extras to prevent CI env drift.
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
23-23: Import looks good.
```
@@ -1,3 +1,4 @@
datasets>=2.14.5
onnx==1.18.0
```
Fix version mismatch with PR objective (onnx 1.19).
This example pins onnx==1.18.0 while the PR upgrades repo/tooling to 1.19.0 and gates tests on >=1.19. Align to avoid feature/API skew (e.g., FP4/INT4 utilities).
```diff
-onnx==1.18.0
+onnx==1.19.0
```
Force-pushed 03d342d to c1ea892
Signed-off-by: ajrasane <[email protected]>
Force-pushed c1ea892 to 02d70b9
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/_test_utils/import_helper.py (1)
16-20: Imports look fine; optional fallback for older Pythons.

If Python <3.8 is still in scope anywhere, add a fallback to the importlib_metadata backport. Otherwise, keep as-is.

Apply if needed:

```diff
+try:
+    import importlib.metadata as importlib_metadata
+except ModuleNotFoundError:  # pragma: no cover
+    import importlib_metadata  # type: ignore
```

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
42-44: Consider skipping at collection time with a decorator to reduce duplication.

Optional: switch to `@pytest.mark.skipif(...)` (using a small predicate in the helper) so tests are marked-skipped at collection, not at runtime. This keeps both tests consistent and avoids executing setup code before the skip.

Example approach:

```diff
- def test_int4_awq(tmp_path):
-     skip_if_onnx_version_above_1_18()
+ @pytest.mark.skipif(_onnx_gt_1_18(), reason="Unsupported for ONNX > 1.18.0")
+ def test_int4_awq(tmp_path):
      ...
- def test_int4_awq_cuda(tmp_path):
-     skip_if_onnx_version_above_1_18()
+ @pytest.mark.skipif(_onnx_gt_1_18(), reason="Unsupported for ONNX > 1.18.0")
+ def test_int4_awq_cuda(tmp_path):
```

Helper (in import_helper.py):

```python
def _onnx_gt_1_18() -> bool:
    try:
        v = importlib.metadata.version("onnx")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(v) > version.parse("1.18.0")
```

Also applies to: 118-121
📒 Files selected for processing (3)
- examples/windows/onnx_ptq/genai_llm/requirements.txt (1 hunks)
- tests/_test_utils/import_helper.py (2 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/windows/onnx_ptq/genai_llm/requirements.txt
🧰 Additional context used
🧬 Code graph analysis (1)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1): skip_if_no_libcudnn (37-43) and skip_if_onnx_version_above_1_18 (80-92) from tests/_test_utils/import_helper.py
🔇 Additional comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
23-23: LGTM: brings in the version gate helper where needed.
```python
def skip_if_onnx_version_above_1_18():
    package_name = "onnx"
    required_version = "1.18.0"

    try:
        installed_version = importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        pytest.skip(f"{package_name} is not installed")

    if version.parse(installed_version) > version.parse(required_version):
        pytest.skip(
            f"{package_name} version {installed_version} is less than required {required_version}"
        )
```
Skip reason text is wrong; variable name misleads; add allow_module_level for consistency.
The condition skips when ONNX > 1.18, but the message says "less than required". Rename the variable to reflect the max supported version, fix the message, and pass allow_module_level=True.
```diff
 def skip_if_onnx_version_above_1_18():
     package_name = "onnx"
-    required_version = "1.18.0"
+    max_supported_version = "1.18.0"

     try:
         installed_version = importlib.metadata.version(package_name)
     except importlib.metadata.PackageNotFoundError:
-        pytest.skip(f"{package_name} is not installed")
+        pytest.skip(f"{package_name} is not installed", allow_module_level=True)

-    if version.parse(installed_version) > version.parse(required_version):
+    if version.parse(installed_version) > version.parse(max_supported_version):
         pytest.skip(
-            f"{package_name} version {installed_version} is less than required {required_version}"
+            f"{package_name} version {installed_version} > supported {max_supported_version}; expected <= {max_supported_version}",
+            allow_module_level=True,
         )
```
Signed-off-by: ajrasane <[email protected]> Signed-off-by: Jingyu Xin <[email protected]>
Signed-off-by: ajrasane <[email protected]>
What does this PR do?
Type of change: Version upgrade
Overview:
Testing
All unit and integration tests pass
Before your PR is "Ready for review"
Summary by CodeRabbit
New Features
Refactor
Chores
Tests